# SymPy grading function

#Dec 24, 2019. Nasser M. Abbasi: 
#              Port of original Maple grading function by 
#              Albert Rich to use with Sympy/Python 
#Dec 27, 2019  Nasser. Added `RootSum`. See problem 177, Timofeev file 
#              added 'exp_polar' 
from sympy import * 
 
def leaf_count(expr):
    """Return an approximate leaf count of ``expr``.

    SymPy has no leaf-count function, so this scales the operation count
    reported by ``count_ops`` by 1.7 and rounds the product.
    """
    ops = count_ops(expr)
    return round(1.7 * ops)
 
def is_sqrt(expr):
    """Return True when ``expr`` is a ``Pow`` whose exponent equals 1/2."""
    return isinstance(expr, Pow) and expr.args[1] == Rational(1, 2)
 
def is_elementary_function(func):
    """Return True if ``func`` is one of the elementary-function classes.

    The set is exp/log, the six circular functions and their inverses, and
    the six hyperbolic functions and their inverses.
    """
    elementary = (
        exp, log, ln,
        sin, cos, tan, cot, sec, csc,
        asin, acos, atan, acot, asec, acsc,
        sinh, cosh, tanh, coth, sech, csch,
        asinh, acosh, atanh, acoth, asech, acsch,
    )
    return func in elementary
 
def is_special_function(func):
    """Return True if ``func`` is a special-function class (grading order 4).

    Covers the error functions, Fresnel integrals, the exponential,
    logarithmic and trigonometric/hyperbolic integrals, the gamma family,
    zeta, polylog, LambertW, the incomplete elliptic integrals and
    ``exp_polar``.
    """
    return func in (
        erf, erfc, erfi,
        fresnels, fresnelc,
        Ei, expint, Li, li, Si, Ci, Shi, Chi,
        # SymPy evaluates digamma(x) to polygamma(0, x), so polygamma must be
        # listed for digamma results to be recognised.
        gamma, loggamma, digamma, polygamma, uppergamma, lowergamma,
        zeta, polylog, LambertW,
        elliptic_f, elliptic_e, elliptic_pi,
        exp_polar,
    )
 
def is_hypergeometric_function(func):
    """Return True if ``func`` is the generalized hypergeometric class ``hyper``."""
    hypergeometric = (hyper,)
    return func in hypergeometric
 
def is_appell_function(func):
    """Return True if ``func`` is the Appell F1 class ``appellf1``."""
    appell = (appellf1,)
    return func in appell
 
def is_atom(expn):
    """Return True if ``expn`` is atomic for grading purposes.

    Python ``int``/``float`` values are atomic, as are SymPy atoms (objects
    whose ``is_Atom`` flag is true, such as symbols and numbers). Anything
    else is not atomic, including objects that have no ``is_Atom``
    attribute at all.
    """
    # Plain Python numbers carry no SymPy flags, so test them first.
    if isinstance(expn, (int, float)):
        return True
    # SymPy spells the flag ``is_Atom`` (a class attribute on Basic).
    return bool(getattr(expn, "is_Atom", False))
 
def expnType(expn):
    """Return the function-class order of ``expn`` on Rich's 1..9 scale.

    The order of an expression is the highest order among its parts:

    1. atoms (see ``is_atom``) and rational powers of rational numbers,
    2. non-integer rational powers (e.g. square roots) of non-rationals,
    3. elementary functions and powers with non-rational exponents,
    4. special functions (see ``is_special_function``),
    5. hypergeometric functions,
    6. Appell functions,
    7. ``RootSum``,
    8. anything whose string form contains ``Integral``,
    9. everything else.

    Integer powers, sums and products take the order of their operands.
    Lists, tuples and SymPy ``Tuple`` take the maximum over their elements;
    an empty container is treated as atomic.
    """
    if is_atom(expn):
        return 1
    # Containers: SymPy ``Tuple`` shows up as Integral limits, ``hyper``
    # parameter lists and Lambda signatures, so it must be graded like a
    # list rather than falling through to order 9.
    if isinstance(expn, (list, tuple, Tuple)):
        return max(map(expnType, expn), default=1)
    if isinstance(expn, Pow):
        base, exponent = expn.args
        if isinstance(exponent, Integer):
            return expnType(base)
        if isinstance(exponent, Rational):
            # Also covers sqrt (exponent 1/2): a root of a rational number
            # stays atomic, otherwise the expression is algebraic.
            if isinstance(base, Rational):
                return 1
            return max(2, expnType(base))
        return max(3, expnType(base), expnType(exponent))
    if isinstance(expn, (Add, Mul)):
        return max(map(expnType, expn.args))

    func = expn.func
    if is_elementary_function(func):
        return max(3, expnType(expn.args[0]))
    if is_special_function(func):
        return max([4, *map(expnType, expn.args)])
    if is_hypergeometric_function(func):
        return max([5, *map(expnType, expn.args)])
    if is_appell_function(func):
        return max([6, *map(expnType, expn.args)])
    if isinstance(expn, RootSum):
        # One of RootSum's args is a Lambda; grade the Lambda's body, since
        # the Lambda object itself would otherwise score 9.
        parts = [a.expr if isinstance(a, Lambda) else a for a in expn.args]
        return max([7, *map(expnType, parts)])
    if "Integral" in str(expn):
        return max([8, *map(expnType, expn.args)])
    return 9
 
#main function 
def grade_antiderivative(result, optimal):
    """Grade the antiderivative ``result`` against ``optimal``.

    Grades, checked in this order:

    * ``"F"`` if the string form of ``result`` contains ``Integral``;
    * ``"C"`` if ``expnType(result)`` exceeds ``expnType(optimal)``, or if
      ``result`` contains ``I`` while ``optimal`` does not;
    * ``"A"`` if ``leaf_count(result)`` is at most twice
      ``leaf_count(optimal)``;
    * ``"B"`` otherwise.

    Args:
        result: SymPy expression to grade.
        optimal: SymPy expression for the optimal antiderivative.

    Returns:
        ``(grade, grade_annotation)``. The annotation is an empty string for
        ``"A"`` and ``"F"`` and a human-readable explanation otherwise.
    """
    if "Integral" in str(result):
        return "F", ""

    expnType_result = expnType(result)
    expnType_optimal = expnType(optimal)
    if expnType_result > expnType_optimal:
        return ("C", "Result contains higher order function than in optimal. Order "
                + str(expnType_result) + " vs. order " + str(expnType_optimal) + ".")

    result_complex = result.has(I)
    if result_complex and not optimal.has(I):
        return "C", "Result contains complex when optimal does not."

    leaf_count_result = leaf_count(result)
    leaf_count_optimal = leaf_count(optimal)
    if leaf_count_result <= 2 * leaf_count_optimal:
        return "A", ""

    # Reaching here with a complex result implies optimal is complex too.
    if result_complex:
        prefix = ("Both result and optimal contain complex but leaf count of result "
                  "is larger than twice the leaf count of optimal. ")
    else:
        prefix = "Leaf count of result is larger than twice the leaf count of optimal. "
    return "B", (prefix + str(leaf_count_result) + " vs. $2 (" + str(leaf_count_optimal)
                 + ") = " + str(2 * leaf_count_optimal) + "$.")